[ADT] Bitset: add shift operators, word accessors, and etc#193400
[ADT] Bitset: add shift operators, word accessors, and etc #193400 — JiachenYuan wants to merge 1 commit into llvm:main from
Conversation
|
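@llvm/pr-subscribers-llvm-adt Author: Jiachen Yuan (JiachenYuan) Changes: This PR is split out from #191757 per reviewer request. It makes the following changes to `llvm::Bitset<N>`: it adds `operator<<`/`<<=`/`>>`/`>>=`, `getNumWords()`, `getWord()`, and `findLastSet()`, and it changes the `std::array<>` constructor from protected to public and explicit.
A follow-up PR will use these to re-implement `LaneBitmask` as an `llvm::Bitset` wrapper. The unit test in the PR is largely generated by LLMs. I have reviewed it and manually applied changes to cover more edge cases. Full diff: https://github.com/llvm/llvm-project/pull/193400.diff 2 Files Affected: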
diff --git a/llvm/include/llvm/ADT/Bitset.h b/llvm/include/llvm/ADT/Bitset.h
index 9dc0f24b1d9f5..3cb2b7d28d83b 100644
--- a/llvm/include/llvm/ADT/Bitset.h
+++ b/llvm/include/llvm/ADT/Bitset.h
@@ -51,8 +51,9 @@ template <unsigned NumBits> class Bitset {
constexpr void maskLastWord() { Bits[getLastWordIndex()] &= RemainderMask; }
-protected:
- constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
+public:
+ explicit constexpr Bitset(
+ const std::array<uint64_t, (NumBits + 63) / 64> &B) {
if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
for (size_t I = 0; I != B.size(); ++I)
Bits[I] = B[I];
@@ -70,8 +71,6 @@ template <unsigned NumBits> class Bitset {
}
maskLastWord();
}
-
-public:
constexpr Bitset() = default;
constexpr Bitset(std::initializer_list<unsigned> Init) {
for (auto I : Init)
@@ -194,6 +193,92 @@ template <unsigned NumBits> class Bitset {
}
return false;
}
+
+ /// Shift all bits N positions towards the most significant end, in place.
+ /// Bits moved past the top are dropped and the vacated low positions become
+ /// zero; shifting by NumBits or more clears the whole set.
+ constexpr Bitset &operator<<=(unsigned N) {
+   if (N == 0)
+     return *this;
+   if (N >= NumBits)
+     return *this = Bitset();
+   // Decompose the shift into a whole-word move and a sub-word bit shift.
+   const unsigned WholeWords = N / BitwordBits;
+   const unsigned SubShift = N % BitwordBits;
+   // Walk from the top word downwards so that every source word is read
+   // before it is overwritten.
+   for (unsigned Dst = NumWords; Dst-- > WholeWords;) {
+     const unsigned Src = Dst - WholeWords;
+     BitWord Moved = Bits[Src] << SubShift;
+     // Carry in the high bits of the next-lower word. Skipped when SubShift
+     // is zero, where the carry shift would be by the full word width.
+     if (SubShift != 0 && Src != 0)
+       Moved |= Bits[Src - 1] >> (BitwordBits - SubShift);
+     Bits[Dst] = Moved;
+   }
+   // The low words have been vacated entirely.
+   for (unsigned Low = 0; Low != WholeWords; ++Low)
+     Bits[Low] = 0;
+   maskLastWord();
+   return *this;
+ }
+
+ /// Return a copy of this bitset shifted N positions towards the most
+ /// significant end; *this is left unchanged.
+ constexpr Bitset operator<<(unsigned N) const {
+   Bitset Shifted = *this;
+   return Shifted <<= N;
+ }
+
+ /// Shift all bits N positions towards the least significant end, in place.
+ /// Bits moved below bit 0 are dropped and the vacated high positions become
+ /// zero; shifting by NumBits or more clears the whole set.
+ constexpr Bitset &operator>>=(unsigned N) {
+   if (N == 0)
+     return *this;
+   if (N >= NumBits)
+     return *this = Bitset();
+   // Decompose the shift into a whole-word move and a sub-word bit shift.
+   const unsigned WholeWords = N / BitwordBits;
+   const unsigned SubShift = N % BitwordBits;
+   // Walk from the bottom word upwards so that every source word is read
+   // before it is overwritten.
+   for (unsigned Dst = 0; Dst + WholeWords < NumWords; ++Dst) {
+     const unsigned Src = Dst + WholeWords;
+     BitWord Moved = Bits[Src] >> SubShift;
+     // Carry in the low bits of the next-higher word. Skipped when SubShift
+     // is zero, where the carry shift would be by the full word width.
+     if (SubShift != 0 && Src + 1 < NumWords)
+       Moved |= Bits[Src + 1] << (BitwordBits - SubShift);
+     Bits[Dst] = Moved;
+   }
+   // The high words have been vacated entirely.
+   for (unsigned High = NumWords - WholeWords; High < NumWords; ++High)
+     Bits[High] = 0;
+   maskLastWord();
+   return *this;
+ }
+
+ /// Return a copy of this bitset shifted N positions towards the least
+ /// significant end; *this is left unchanged.
+ constexpr Bitset operator>>(unsigned N) const {
+   Bitset Shifted = *this;
+   return Shifted >>= N;
+ }
+
+ /// Return the I-th 64-bit word of the bitset, from least significant to most.
+ /// The result is a 64-bit view regardless of the width of the underlying
+ /// storage word. An index past the last stored word yields 0 on both
+ /// storage widths.
+ constexpr uint64_t getWord(unsigned I) const {
+   if constexpr (BitwordBits == 64) {
+     // Guard the index the same way the 32-bit path does, rather than
+     // reading past the end of Bits.
+     return I < NumWords ? Bits[I] : uint64_t(0);
+   } else {
+     static_assert(BitwordBits == 32, "Unsupported word size");
+     // Two 32-bit storage words form one 64-bit word; either half may be
+     // absent, in which case it contributes zero.
+     const uint64_t Lo = (2 * I < NumWords) ? Bits[2 * I] : 0;
+     const uint64_t Hi = (2 * I + 1 < NumWords) ? Bits[2 * I + 1] : 0;
+     return Lo | (Hi << 32);
+   }
+ }
+
+ /// Return the index of the most significant set bit, or -1 when the bitset
+ /// is empty.
+ constexpr int findLastSet() const {
+   // Scan storage words from the top down; the first non-zero word holds
+   // the answer.
+   unsigned Idx = NumWords;
+   while (Idx != 0) {
+     --Idx;
+     if (Bits[Idx] == 0)
+       continue;
+     // Position of the highest set bit inside this word.
+     const auto TopBit = BitwordBits - 1 - countl_zero_constexpr(Bits[Idx]);
+     return Idx * BitwordBits + TopBit;
+   }
+   return -1;
+ }
+
+ /// Return how many 64-bit words are needed to hold NumBits bits (NumBits
+ /// divided by 64, rounded up).
+ static constexpr unsigned getNumWords() {
+   constexpr unsigned BitsPerWord64 = 64;
+   return (NumBits + BitsPerWord64 - 1) / BitsPerWord64;
+ }
};
} // end namespace llvm
diff --git a/llvm/unittests/ADT/BitsetTest.cpp b/llvm/unittests/ADT/BitsetTest.cpp
index 678197e31a379..ee3ef07d01979 100644
--- a/llvm/unittests/ADT/BitsetTest.cpp
+++ b/llvm/unittests/ADT/BitsetTest.cpp
@@ -294,4 +294,202 @@ TEST(BitsetTest, BitwiseOperators) {
TestXor128.test(127));
}
+TEST(BitsetTest, ShiftOperators) { // Every check is a compile-time static_assert.
+ // Left shift moves bits to higher indices; bits pushed past the top are dropped.
+ static_assert((Bitset<64>({0}) << 10).test(10));
+ static_assert(!(Bitset<64>({0}) << 10).test(0));
+ static_assert((Bitset<64>({63}) << 1).none());
+ static_assert((Bitset<128>({0}) << 64).test(64));
+ static_assert((Bitset<128>({63}) << 1).test(64)); // Crosses the bit 63/64 boundary.
+ static_assert((Bitset<128>({127}) << 1).none());
+
+ // Right shift moves bits to lower indices; bits pushed below bit 0 are dropped.
+ static_assert((Bitset<64>({10}) >> 10).test(0));
+ static_assert(!(Bitset<64>({10}) >> 10).test(10));
+ static_assert((Bitset<64>({0}) >> 1).none());
+ static_assert((Bitset<128>({64}) >> 64).test(0));
+ static_assert((Bitset<128>({64}) >> 1).test(63)); // Crosses the bit 64/63 boundary.
+ static_assert((Bitset<128>({0}) >> 1).none());
+
+ // Shifting by 0 in either direction leaves the bitset unchanged.
+ static_assert((Bitset<64>({10, 20}) << 0) == Bitset<64>({10, 20}));
+ static_assert((Bitset<64>({10, 20}) >> 0) == Bitset<64>({10, 20}));
+
+ // Shifting by exactly NumBits clears every bit.
+ static_assert((Bitset<64>({0, 63}) << 64).none());
+ static_assert((Bitset<64>({0, 63}) >> 64).none());
+ static_assert((Bitset<128>({0, 127}) << 128).none());
+ static_assert((Bitset<128>({0, 127}) >> 128).none());
+}
+
+TEST(BitsetTest, GetNumWords64) { // getNumWords() counts 64-bit words, rounding up.
+ static_assert(Bitset<1>::getNumWords() == 1);
+ static_assert(Bitset<32>::getNumWords() == 1);
+ static_assert(Bitset<64>::getNumWords() == 1);
+ static_assert(Bitset<65>::getNumWords() == 2); // One bit over a word needs a second word.
+ static_assert(Bitset<96>::getNumWords() == 2);
+ static_assert(Bitset<128>::getNumWords() == 2);
+ static_assert(Bitset<129>::getNumWords() == 3);
+}
+
+TEST(BitsetTest, GetWord) { // Bitsets are built from raw 64-bit words via the std::array constructor.
+ // Single-word bitset: the stored word is returned unchanged.
+ constexpr auto B64 = Bitset<64>(std::array<uint64_t, 1>{0xdeadbeefcafe1234});
+ static_assert(B64.getWord(0) == 0xdeadbeefcafe1234);
+
+ // Multi-word bitset: word 0 is the least significant word.
+ constexpr auto B128 = Bitset<128>(
+ std::array<uint64_t, 2>{0x1111222233334444, 0xaaaabbbbccccdddd});
+ static_assert(B128.getWord(0) == 0x1111222233334444);
+ static_assert(B128.getWord(1) == 0xaaaabbbbccccdddd);
+
+ // Partial last word: Bitset<96> uses only 32 bits of its second 64-bit word, so the rest must read as zero.
+ constexpr auto B96 = Bitset<96>(
+ std::array<uint64_t, 2>{0xffffffffffffffff, 0xffffffffffffffff});
+ static_assert(B96.getWord(0) == 0xffffffffffffffff);
+ // Only the lower 32 bits of the second word survive.
+ static_assert(B96.getWord(1) == 0x00000000ffffffff);
+
+ // A default-constructed bitset reads as all-zero words.
+ static_assert(Bitset<64>().getWord(0) == 0);
+ static_assert(Bitset<128>().getWord(0) == 0);
+ static_assert(Bitset<128>().getWord(1) == 0);
+}
+
+TEST(BitsetTest, FindLastSet) { // findLastSet() reports the highest set bit index.
+ // An empty bitset has no set bit, which is reported as -1.
+ static_assert(Bitset<64>().findLastSet() == -1);
+ static_assert(Bitset<128>().findLastSet() == -1);
+
+ // With a single bit set, that bit's index is returned.
+ static_assert(Bitset<64>({0}).findLastSet() == 0);
+ static_assert(Bitset<64>({63}).findLastSet() == 63);
+ static_assert(Bitset<64>({31}).findLastSet() == 31);
+ static_assert(Bitset<128>({0}).findLastSet() == 0);
+ static_assert(Bitset<128>({64}).findLastSet() == 64);
+ static_assert(Bitset<128>({127}).findLastSet() == 127);
+
+ // With several bits set, the highest index wins.
+ static_assert(Bitset<64>({0, 10, 50}).findLastSet() == 50);
+ static_assert(Bitset<128>({0, 63, 64, 100}).findLastSet() == 100);
+
+ // With all bits set, the result is NumBits - 1.
+ static_assert(Bitset<64>().set().findLastSet() == 63);
+ static_assert(Bitset<128>().set().findLastSet() == 127);
+ static_assert(Bitset<96>().set().findLastSet() == 95);
+
+ // Sizes that are not a multiple of 64.
+ static_assert(Bitset<33>({32}).findLastSet() == 32);
+ static_assert(Bitset<33>({0, 32}).findLastSet() == 32);
+ static_assert(Bitset<65>({64}).findLastSet() == 64);
+}
+
+TEST(BitsetTest, ShiftMultiWords) { // 192-bit sets span three 64-bit words.
+ constexpr auto B192 = Bitset<192>({0, 64, 128}); // Lowest bit of each 64-bit word.
+ static_assert((B192 << 1) == Bitset<192>({1, 65, 129}));
+ static_assert((B192 >> 1) == Bitset<192>({63, 127})); // Bit 0 is shifted out.
+ static_assert((B192 << 64) == Bitset<192>({64, 128})); // Bit 128 is shifted out.
+ static_assert((B192 >> 64) == Bitset<192>({0, 64})); // Bit 0 is shifted out.
+ static_assert((Bitset<192>({63, 127}) << 1) == Bitset<192>({64, 128}));
+ static_assert((Bitset<192>({64, 128}) >> 1) == Bitset<192>({63, 127}));
+}
+
+TEST(BitsetTest, ShiftBoundaryBitShifts) { // Shift amounts of 63, one short of a 64-bit word.
+ static_assert((Bitset<128>({1}) << 63) == Bitset<128>({64}));
+ static_assert((Bitset<128>({64}) >> 63) == Bitset<128>({1}));
+ static_assert((Bitset<192>({1, 65}) << 63) == Bitset<192>({64, 128}));
+ // Shift by NumBits - 1, the largest amount that does not clear the set.
+ static_assert((Bitset<64>({0}) << 63) == Bitset<64>({63}));
+ static_assert((Bitset<64>({63}) >> 63) == Bitset<64>({0}));
+ static_assert((Bitset<33>({0}) << 32) == Bitset<33>({32}));
+ // Shifting a fully-set bitset by one position loses exactly one bit.
+ static_assert((Bitset<128>().set() << 1).count() == 127);
+ static_assert((Bitset<128>().set() >> 1).count() == 127);
+ static_assert((Bitset<100>().set() >> 1).count() == 99);
+}
+
+TEST(BitsetTest, ShiftExcessAmount) { // Shift amounts larger than NumBits clear the set.
+ static_assert((Bitset<64>().set() << 65).none());
+ static_assert((Bitset<64>().set() >> 200).none());
+ static_assert((Bitset<33>({0, 10, 32}) << 1000).none());
+ static_assert((Bitset<128>({0, 127}) >> 1000).none());
+ static_assert((Bitset<192>().set() << 193).none());
+}
+
+TEST(BitsetTest, ShiftAssignReturnsReference) { // Chained <<= / >>= must act on the same object.
+ constexpr Bitset<64> L = [] {
+ Bitset<64> X({0});
+ (X <<= 3) <<= 2; // Second shift applies to X through the returned reference: 0 + 3 + 2.
+ return X;
+ }();
+ static_assert(L == Bitset<64>({5}));
+
+ constexpr Bitset<128> R = [] {
+ Bitset<128> X({100});
+ (X >>= 30) >>= 10; // 100 - 30 - 10.
+ return X;
+ }();
+ static_assert(R == Bitset<128>({60}));
+}
+
+TEST(BitsetTest, GetWordConsistencyWithTest) {
+ // Bit B must appear in 64-bit word B / 64 at position B % 64.
+ constexpr auto B100 = Bitset<100>({0, 50, 64, 99});
+ static_assert((B100.getWord(0) & 1) != 0); // Bit 0.
+ static_assert((B100.getWord(0) & (uint64_t(1) << 50)) != 0); // Bit 50.
+ static_assert((B100.getWord(1) & 1) != 0); // Bit 64 is bit 0 of word 1.
+ static_assert((B100.getWord(1) & (uint64_t(1) << 35)) != 0); // Bit 99 is bit 35 of word 1.
+}
+
+TEST(BitsetTest, GetWordAfterMutation) {
+ // getWord must reflect bits changed through set() and through a shift.
+ constexpr auto B = [] {
+ Bitset<128> X;
+ X.set(5).set(70);
+ return X;
+ }();
+ static_assert(B.getWord(0) == (uint64_t(1) << 5));
+ static_assert(B.getWord(1) == (uint64_t(1) << 6)); // Bit 70 is bit 6 of word 1.
+
+ constexpr auto Shifted = Bitset<128>({5}) << 64; // Moves bit 5 up one whole word.
+ static_assert(Shifted.getWord(0) == 0);
+ static_assert(Shifted.getWord(1) == (uint64_t(1) << 5));
+}
+
+TEST(BitsetTest, GetNumWordsMoreWidths) { // Further widths around multiples of 64.
+ static_assert(Bitset<2>::getNumWords() == 1);
+ static_assert(Bitset<192>::getNumWords() == 3);
+ static_assert(Bitset<193>::getNumWords() == 4); // One bit over three words needs a fourth.
+ static_assert(Bitset<256>::getNumWords() == 4);
+}
+
+TEST(BitsetTest, FindLastSetSmallWidths) { // Bitsets narrower than one 64-bit word.
+ static_assert(Bitset<1>().findLastSet() == -1);
+ static_assert(Bitset<1>({0}).findLastSet() == 0);
+ static_assert(Bitset<2>({0, 1}).findLastSet() == 1);
+ static_assert(Bitset<32>({31}).findLastSet() == 31);
+ static_assert(Bitset<32>().set().findLastSet() == 31);
+}
+
+TEST(BitsetTest, FindLastSetMultiWordScan) { // The highest set bit lies below empty upper words.
+ static_assert(Bitset<192>({70}).findLastSet() == 70);
+ static_assert(Bitset<192>({64, 70, 127}).findLastSet() == 127);
+ static_assert(Bitset<192>({3}).findLastSet() == 3);
+ static_assert(Bitset<100>({99}).findLastSet() == 99);
+}
+
+TEST(BitsetTest, FindLastSetAfterMutation) { // findLastSet() after reset() and after shifts.
+ constexpr auto A = Bitset<128>({0, 50, 100}).reset(100); // Clearing the top bit exposes bit 50.
+ static_assert(A.findLastSet() == 50);
+
+ constexpr auto B = Bitset<64>({10}) << 20;
+ static_assert(B.findLastSet() == 30);
+
+ constexpr auto C = Bitset<64>({63}) >> 10;
+ static_assert(C.findLastSet() == 53);
+
+ constexpr auto D = Bitset<64>({63}) << 1; // The only set bit is shifted out.
+ static_assert(D.findLastSet() == -1);
+}
+
} // namespace
|
7f36829 to
657c2db
Compare
|
Adding @arsenm and @s-barannikov for viz. Thank you! |
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
|
Failure related to this: #193558. Rebasing and testing again. |
657c2db to
7cbc4b4
Compare
| } | ||
|
|
||
| /// Return the number of 64-bit words needed to hold all bits. | ||
| static constexpr unsigned getNumWords() { return (NumBits + 63) / 64; } |
There was a problem hiding this comment.
I think this question is related to the one below, and I tried to reply to both here.
| } | ||
|
|
||
| /// Return the I-th 64-bit word of the bitset, from least significant to most. | ||
| constexpr uint64_t getWord(unsigned I) const { |
There was a problem hiding this comment.
I feel like the naming is confusing because getWord would suggest that we just get the I-th word from the array, which would not require any operation other than Bits[I] and wouldn't need bitsize inspection. I would expect the return value of a getWord routine to be BitWord
There was a problem hiding this comment.
BitwordBits is architecture-dependent -- it could be 32-bit or 64-bit. Instead of making both getWord and getNumWords thin getters, my intention was that we provide a normalized 64-bit view to the external accessors. This way, we are not exposing every implementation detail to the consumers of Bitset. On the other hand, I think the names are indeed a little bit confusing. Would it make more sense if I rename them to be getWord64() and getNumWords64()?
There was a problem hiding this comment.
yes, I guess that would make more sense
There was a problem hiding this comment.
Sure, I have changed the function names accordingly.
| /// Return the I-th 64-bit word of the bitset, from least significant to most. | ||
| constexpr uint64_t getWord(unsigned I) const { | ||
| if constexpr (BitwordBits == 64) { | ||
| return Bits[I]; |
There was a problem hiding this comment.
could also do some check or workaround for index-out-of-bounds, like you did for the 32bit case
There was a problem hiding this comment.
Sure, thanks for catching this! I added an assertion to check index-out-of-bounds.
7cbc4b4 to
27bc031
Compare
27bc031 to
e9a137f
Compare
This PR is split out from #191757 per reviewer request. It has the following changes to `llvm::Bitset<N>`:
- Added `operator<<`/`<<=`/`>>`/`>>=`, `getNumWords()`, `getWord()`, and `findLastSet()`.
- Changed the `std::array<>` constructor from protected to public and explicit.

A follow-up PR will use these to re-implement `LaneBitmask` as an `llvm::Bitset` wrapper.

The unit test in the PR is largely generated by LLMs. I have reviewed it and manually applied changes to cover more edge cases.